fix: port moe routing to new triton_kernels API #214
The `triton_kernels.routing` module was deprecated and removed in triton commit 30ede52aa (triton-lang/triton#8375). This PR replaces the deprecated `routing()` call with the new primitives in `compute_routing()`.

---

Upgrade to `triton>=3.5`. `triton_kernels` HEAD relies on `tl.target_info()`, which is not available in 3.4.
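As a rough illustration of the `triton>=3.5` requirement, the check below compares an installed version string against the minimum. It is only a sketch: the helper names (`parse_version`, `triton_is_new_enough`, `check_installed_triton`) are hypothetical and not part of this PR, and it assumes the package is distributed under the name `triton`.

```python
import re
from importlib import metadata


def parse_version(text):
    """Return the leading (major, minor) pair of a version string like '3.5.0+git1234'."""
    match = re.match(r"(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return int(match.group(1)), int(match.group(2))


def triton_is_new_enough(version_text, minimum=(3, 5)):
    """Compare numerically, so that '3.10' counts as newer than '3.5'."""
    return parse_version(version_text) >= minimum


def check_installed_triton():
    """Return True only if a `triton` distribution is installed and is at least 3.5."""
    try:
        installed = metadata.version("triton")  # assumed distribution name
    except metadata.PackageNotFoundError:
        return False
    return triton_is_new_enough(installed)
```

Comparing integer tuples rather than raw strings avoids the common mistake of treating `"3.10"` as older than `"3.5"`.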
💡 Codex Review (lines 37 to 55 in 04ebbfd)
@codex I don't think that's correct; it doesn't match the original PR, triton-lang/triton#8375:

    def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
        sparse_logits = SparseMatrix(indx=expt_indx, vals=expt_scal, mask=bitmatrix)
        dispatch_indx = sparse_logits.mask.metadata.col_sorted_indx
        combine_indx = sparse_logits.mask.metadata.row_sorted_indx
        ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask.metadata.col_sum, dispatch_indx.shape[0])
        gate_scal = sparse_logits.vals.flatten()[combine_indx]  # combine_indx, not dispatch.
        routing_data = RoutingData(gate_scal, ragged_batch_metadata.batch_sizes, n_expts_tot, n_expts_act,
                                   ragged_batch_metadata)
        gather_idx = GatherIndx(combine_indx, dispatch_indx)
        scatter_idx = ScatterIndx(dispatch_indx, combine_indx)
        return routing_data, gather_idx, scatter_idx

    def legacy_routing(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
        if sm_first:
            logits = torch.softmax(logits, dim=-1)
        sparse_logits = topk(logits, n_expts_act, apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
        return legacy_routing_from_bitmatrix(sparse_logits.mask, sparse_logits.vals, sparse_logits.indx, logits.shape[-1],
                                             n_expts_act)

I think it's confused because `row_sorted_indx` references col-major positions, but the entries are grouped by row.
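To make the indexing concrete, here is a minimal pure-Python sketch of sorting flattened top-k assignments by expert. It is not the `triton_kernels` implementation. The names `sort_by_expert`, `sorted_to_flat`, and `flat_to_sorted` are hypothetical, and the assumption that `combine_indx` plays the role of `sorted_to_flat` is inferred only from the line `gate_scal = sparse_logits.vals.flatten()[combine_indx]` above.

```python
def sort_by_expert(expt_indx, expt_scal):
    """Group flattened (token, slot) entries by expert id.

    expt_indx / expt_scal are per-token lists of the chosen expert ids and
    their gate values. Returns two mutually inverse permutations plus the
    gate values reordered into expert-sorted order.
    """
    # Flatten row-major: entry i belongs to token i // k, slot i % k.
    flat_expt = [e for row in expt_indx for e in row]
    flat_scal = [s for row in expt_scal for s in row]

    # Stable sort of flat positions by expert id:
    # sorted_to_flat[p] is the flat position of the entry at sorted position p.
    sorted_to_flat = sorted(range(len(flat_expt)), key=lambda i: flat_expt[i])

    # Inverse permutation: flat_to_sorted[i] is where flat entry i ends up.
    flat_to_sorted = [0] * len(flat_expt)
    for sorted_pos, flat_pos in enumerate(sorted_to_flat):
        flat_to_sorted[flat_pos] = sorted_pos

    # Indexing the flat gate values with sorted_to_flat yields expert-sorted
    # order (assumed analogue of `vals.flatten()[combine_indx]`).
    gate_scal = [flat_scal[i] for i in sorted_to_flat]
    return sorted_to_flat, flat_to_sorted, gate_scal


# Two tokens, top-2 experts each.
order, inverse, gates = sort_by_expert([[2, 0], [0, 1]], [[0.7, 0.3], [0.6, 0.4]])
```

Under these assumptions, indexing the gate values with the inverse permutation instead would leave them in an order that does not line up with the expert-grouped rows, which is the kind of mix-up the comment above warns against.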
@dkundel-openai any update on this? gpt-oss is still broken on latest |